{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zheyuanzhang/anaconda3/envs/FRS/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "from RCSYS_utils import *\n",
    "from RCSYS_models import *\n",
    "from utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Benchmark graph to augment with prompts. Change here to process another benchmark.\n",
    "# The file is read with torch.load and later overwritten in place by torch.save.\n",
    "benchmark_path = '../processed_data/benchmark_macro.pt'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Benchmark graph; indexed below as data['user'] and data['food'].\n",
    "# NOTE(review): torch.load unpickles the file, so only load benchmark files from a trusted source.\n",
    "data = torch.load(benchmark_path)\n",
    "# Tagging tables for users and foods, plus the FNDDS table used to build the food prompts\n",
    "df_user = pd.read_csv('../processed_data/user_tagging.csv')\n",
    "df_food = pd.read_csv('../processed_data/food_tagging.csv')\n",
    "df_fndds = pd.read_csv('../processed_data/fndds.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Code -> label lookup tables for the demographic fields used in the user prompt\n",
    "gender_dict = {1: 'male', 2: 'female'}\n",
    "race_dict = {\n",
    "    0: 'Missing',\n",
    "    1: 'Mexican American',\n",
    "    2: 'Other Hispanic',\n",
    "    3: 'White',\n",
    "    4: 'Black',\n",
    "    5: 'Other race'\n",
    "}\n",
    "education_dict = {\n",
    "    0: 'Missing',\n",
    "    1: 'Less than 9th grade',\n",
    "    2: '9-11th grade',\n",
    "    3: 'GED or equivalent',\n",
    "    4: 'Some college or AA degree',\n",
    "    5: 'College graduate or above',\n",
    "    7: 'Refused',\n",
    "    9: \"Don't know\"\n",
    "}\n",
    "\n",
    "\n",
    "def create_prompt(row):\n",
    "    \"\"\"Build the demographic prompt sentence for one user record.\n",
    "\n",
    "    Args:\n",
    "        row: Mapping-like record (for example one row of df_user) providing\n",
    "            'SEQN', 'gender', 'age', 'race', 'household_income' and 'education'.\n",
    "\n",
    "    Returns:\n",
    "        str: A sentence starting with 'User Node <SEQN>: The user information\n",
    "        is as follows: ...'.\n",
    "\n",
    "    A code that has no entry in the lookup tables (including NaN) is rendered\n",
    "    as 'Missing' rather than raising KeyError part-way through DataFrame.apply.\n",
    "    \"\"\"\n",
    "    gender = gender_dict.get(row['gender'], 'Missing')\n",
    "    race = race_dict.get(row['race'], 'Missing')\n",
    "    education = education_dict.get(row['education'], 'Missing')\n",
    "    return (\n",
    "        f\"User Node {row['SEQN']}: The user information is as follows: \"\n",
    "        f\"{gender}, age {row['age']}, {race}, \"\n",
    "        f\"household income level (the higher the better): {row['household_income']}, \"\n",
    "        f\"education status: {education}.\"\n",
    "    )\n",
    "\n",
    "# User-level health indicator columns, listed in the order their tags are emitted\n",
    "nutrition_columns = [\n",
    "    'user_low_carb', 'user_low_phosphorus', 'user_low_calorie',\n",
    "    'user_high_calorie', 'user_high_potassium', 'user_low_sodium',\n",
    "    'user_low_cholesterol', 'user_low_saturated_fat', 'user_low_protein',\n",
    "    'user_high_protein', 'user_low_sugar', 'user_high_fiber',\n",
    "    'user_high_iron', 'user_high_folate_acid', 'user_high_vitamin_b12',\n",
    "    'user_high_calcium', 'user_high_vitamin_d', 'user_high_vitamin_c'\n",
    "]\n",
    "\n",
    "\n",
    "def create_health_tag_prompt(row):\n",
    "    \"\"\"Return the comma-separated health tags that equal 1 for one user record.\n",
    "\n",
    "    Each tag is the column name with 'user_' removed and underscores replaced\n",
    "    by spaces; a record with no active column gives an empty string.\n",
    "    \"\"\"\n",
    "    active_columns = [name for name in nutrition_columns if row[name] == 1]\n",
    "    return ', '.join(\n",
    "        name.replace('user_', '').replace('_', ' ') for name in active_columns\n",
    "    )\n",
    "\n",
    "# Add the comma-separated health-tag text for every user row\n",
    "df_user['prompt_health'] = df_user.apply(create_health_tag_prompt, axis=1)\n",
    "\n",
    "# Add the demographic prompt sentence for every user row\n",
    "df_user['prompt'] = df_user.apply(create_prompt, axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the per-user prompt columns, keyed by SEQN, without the DataFrame index\n",
    "df_user[['SEQN', 'prompt', 'prompt_health']].to_csv('../processed_data/user_prompt.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Food-level nutrient tag columns (one low_/high_ pair per nutrient), in emission order\n",
    "food_nutrition_columns = [\n",
    "    'low_calorie', 'high_calorie',\n",
    "    'low_protein', 'high_protein',\n",
    "    'low_carb', 'high_carb',\n",
    "    'low_sugar', 'high_sugar',\n",
    "    'low_fiber', 'high_fiber',\n",
    "    'low_saturated_fat', 'high_saturated_fat',\n",
    "    'low_cholesterol', 'high_cholesterol',\n",
    "    'low_sodium', 'high_sodium',\n",
    "    'low_calcium', 'high_calcium',\n",
    "    'low_phosphorus', 'high_phosphorus',\n",
    "    'low_potassium', 'high_potassium',\n",
    "    'low_iron', 'high_iron',\n",
    "    'low_folic_acid', 'high_folic_acid',\n",
    "    'low_vitamin_c', 'high_vitamin_c',\n",
    "    'low_vitamin_d', 'high_vitamin_d',\n",
    "    'low_vitamin_b12', 'high_vitamin_b12'\n",
    "]\n",
    "\n",
    "\n",
    "def create_food_tag_prompt(row):\n",
    "    \"\"\"Return the comma-separated nutrient tags that equal 1 for one food record.\n",
    "\n",
    "    Each tag is the column name with underscores replaced by spaces; a record\n",
    "    with no active column gives an empty string.\n",
    "    \"\"\"\n",
    "    return ', '.join(\n",
    "        column.replace('_', ' ')\n",
    "        for column in food_nutrition_columns\n",
    "        if row[column] == 1\n",
    "    )\n",
    "\n",
    "# Attach the comma-separated nutrient-tag text to every food row\n",
    "df_food['prompt_health'] = df_food.apply(create_food_tag_prompt, axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_food_prompt(group):\n",
    "    \"\"\"Build the descriptive prompt for one food from its grouped rows.\n",
    "\n",
    "    Args:\n",
    "        group: DataFrame holding the rows of a single food_id, with columns\n",
    "            'food_id', 'food_desc', 'WWEIA_desc' and 'ingredient_desc'.\n",
    "\n",
    "    Returns:\n",
    "        pd.Series: One entry, 'food_prompt_text', containing the prompt.\n",
    "    \"\"\"\n",
    "    # Food-level fields are taken from the first row of the group\n",
    "    food_id = group['food_id'].iloc[0]\n",
    "    food_desc = group['food_desc'].iloc[0]\n",
    "    food_category = group['WWEIA_desc'].iloc[0]\n",
    "    # str.join raises TypeError on NaN, so skip missing ingredient descriptions\n",
    "    # and coerce the remaining values to str before joining\n",
    "    ingredients = ', '.join(group['ingredient_desc'].dropna().astype(str))\n",
    "    prompt = (\n",
    "        f'Food Node {food_id}: The food description is: {food_desc}. '\n",
    "        f'This food belongs to the category: {food_category}. '\n",
    "        f'The ingredients in this food are: {ingredients}.'\n",
    "    )\n",
    "    return pd.Series({'food_prompt_text': prompt})\n",
    "\n",
    "# Build one prompt per food_id; reset_index turns the group key back into a regular column\n",
    "df_prompts = df_fndds.groupby('food_id').apply(create_food_prompt).reset_index()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Drop any prompt column left by an earlier run of this cell; otherwise the merge\n",
    "# would emit food_prompt_text_x / food_prompt_text_y and later cells could no\n",
    "# longer select 'food_prompt_text'.\n",
    "df_food = (\n",
    "    df_food\n",
    "    .drop(columns=['food_prompt_text'], errors='ignore')\n",
    "    .merge(df_prompts, on='food_id', how='left')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the per-food prompt columns, keyed by food_id, without the DataFrame index\n",
    "df_food[['food_id', 'food_prompt_text', 'prompt_health']].to_csv('../processed_data/food_prompt.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Lookup tables from user identifier (SEQN) to prompt text\n",
    "user_prompt_dict = df_user.set_index('SEQN')['prompt'].to_dict()\n",
    "user_prompt_health_dict = df_user.set_index('SEQN')['prompt_health'].to_dict()\n",
    "\n",
    "# Lookup tables from food_id to prompt text. Foods with no row in df_prompts come\n",
    "# out of the left merge with NaN in food_prompt_text; blank those so every stored\n",
    "# food prompt is a string instead of a float NaN.\n",
    "food_prompt_dict = df_food.set_index('food_id')['food_prompt_text'].fillna('').to_dict()\n",
    "food_prompt_health_dict = df_food.set_index('food_id')['prompt_health'].to_dict()\n",
    "\n",
    "# Convert each node_id container once instead of calling .item() for every node\n",
    "user_node_ids = data['user'].node_id.tolist()\n",
    "food_node_ids = data['food'].node_id.tolist()\n",
    "\n",
    "# One entry per node, in node order; a node without a matching table row keeps ''\n",
    "user_prompt = [user_prompt_dict.get(seqn, '') for seqn in user_node_ids]\n",
    "user_prompt_health = [user_prompt_health_dict.get(seqn, '') for seqn in user_node_ids]\n",
    "food_prompt = [food_prompt_dict.get(food_id, '') for food_id in food_node_ids]\n",
    "food_prompt_health = [food_prompt_health_dict.get(food_id, '') for food_id in food_node_ids]\n",
    "\n",
    "# Add prompt features to HeteroData\n",
    "data['user'].prompt = user_prompt\n",
    "data['user'].prompt_health = user_prompt_health\n",
    "data['food'].prompt = food_prompt\n",
    "data['food'].prompt_health = food_prompt_health"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "HeteroData(\n",
       "  user={\n",
       "    x=[8170, 38],\n",
       "    node_id=[8170],\n",
       "    num_nodes=8170,\n",
       "    tags=[8170, 14],\n",
       "    prompt=[8170],\n",
       "    prompt_health=[8170],\n",
       "  },\n",
       "  food={\n",
       "    x=[6769, 66],\n",
       "    node_id=[6769],\n",
       "    num_nodes=6769,\n",
       "    tags=[6769, 14],\n",
       "    prompt=[6769],\n",
       "    prompt_health=[6769],\n",
       "  },\n",
       "  (user, eats, food)={\n",
       "    edge_index=[2, 314224],\n",
       "    edge_label_index=[2, 122009],\n",
       "  }\n",
       ")"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the graph summary to check that prompt / prompt_health are attached to both node types\n",
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the augmented graph back to benchmark_path, replacing the file that was loaded\n",
    "torch.save(data, benchmark_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "FRS",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
